{
 "cells": [
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "# 模型训练和特征重要性分析\n",
    "\n",
    "在本步骤中，我们将完成以下流程：\n",
    "\n",
    "1. 加载经过预处理的训练数据\n",
    "2. 使用随机森林算法训练模型\n",
    "3. 提取并保存模型中的特征重要性\n",
    "4. 可视化和保存前 30 个特征的重要性图\n",
    "\n",
    "这些步骤将帮助我们了解在模型中起关键作用的特征，便于在后续分析中进行特征选择和优化。\n"
   ],
   "id": "dc5d1e93f853f635"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "initial_id",
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# 加载必要的库\n",
    "import pandas as pd\n",
    "import joblib\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n"
   ]
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "### 导入库\n",
    "在此单元格中，我们导入所需的库以便于数据处理、模型训练和绘图：\n",
    "- `pandas` 用于数据操作\n",
    "- `joblib` 用于模型的保存\n",
    "- `RandomForestClassifier` 从 `sklearn` 中导入，用于训练随机森林模型\n",
    "- `matplotlib.pyplot` 用于绘制特征重要性图表\n"
   ],
   "id": "9a0b3eebe359404e"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# 加载预处理后的训练数据\n",
    "data_dir = '../data/processed'\n",
    "X_train = pd.read_csv(os.path.join(data_dir, 'X_train.csv'))\n",
    "y_train = pd.read_csv(os.path.join(data_dir, 'y_train.csv')).values.ravel()\n"
   ],
   "id": "a54673612229b28f"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "### 加载预处理后的训练数据\n",
    "在此单元格中，我们从指定的数据目录 `data_dir` 中加载训练集的数据。\n",
    "- `X_train` 包含训练数据的特征\n",
    "- `y_train` 包含目标标签（`values.ravel()` 用于将标签数据转换为一维数组，以便兼容模型训练）\n"
   ],
   "id": "657e736eab9e2f0d"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# 训练随机森林模型\n",
    "model = RandomForestClassifier(n_estimators=100, random_state=42)\n",
    "model.fit(X_train, y_train)\n"
   ],
   "id": "cd49e675d05578a2"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "### 训练随机森林模型\n",
    "在此单元格中，我们创建并训练随机森林模型：\n",
    "- 使用 `RandomForestClassifier` 初始化模型，指定参数 `n_estimators=100`（即使用100棵决策树）和 `random_state=42` 以确保结果可重复\n",
    "- 调用 `fit` 方法，将训练特征 `X_train` 和标签 `y_train` 输入模型进行训练\n"
   ],
   "id": "1ee435366a273ddb"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# 提取并保存特征重要性\n",
    "feature_importance = pd.Series(model.feature_importances_, index=X_train.columns)\n",
    "feature_importance = feature_importance.sort_values(ascending=False)\n",
    "feature_importance.to_csv('../data/processed/feature_importance.csv', header=True)\n"
   ],
   "id": "a6d075cb77cc2257"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "### 提取并保存特征重要性\n",
    "在此单元格中，我们提取并保存特征的重要性：\n",
    "- `model.feature_importances_` 提取每个特征的重要性评分\n",
    "- 将特征重要性数据创建为 `pandas Series`，使用特征列名作为索引\n",
    "- 使用 `sort_values` 方法按重要性从高到低排序，并保存到 `feature_importance.csv`\n"
   ],
   "id": "fb6f502574d783dc"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# 绘制特征重要性图\n",
    "plt.figure(figsize=(10, 6))\n",
    "feature_importance[:30].plot(kind='bar')\n",
    "plt.title(\"Top 30 Feature Importances\")\n",
    "plt.savefig('../data/processed/feature_importance.png')\n",
    "plt.show()\n"
   ],
   "id": "41d084c306739ba8"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "### 绘制并保存特征重要性图\n",
    "在此单元格中，我们绘制并保存前30个重要特征的柱状图：\n",
    "- 设置图表大小 `figsize=(10, 6)`\n",
    "- 使用 `plot` 方法绘制前30个重要特征的柱状图\n",
    "- 添加图表标题 \"Top 30 Feature Importances\"\n",
    "- 将图表保存为 `feature_importance.png`，并展示出来\n"
   ],
   "id": "2dca49390e1084fa"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
